"""
Phase 4: Monetary Transmission and Aging (Local Projections)
Does aging weaken monetary transmission? Test via local projections of
cumulative output growth and cumulative inflation on delta_policy_rate,
interacted with Z_1.

Key hypothesis: Z_1_x_delta_rate > 0 on cum_growth means aging weakens
the negative output effect of rate hikes.

Outputs:
  - Table 5: Growth response (phase4_table5_transmission_growth.md)
  - Table 5b: Investment channel (phase4_table5b_transmission_investment.md)
  - Table 5c: Consumption channel (phase4_table5c_transmission_consumption.md)
  - Table 6: Inflation response (phase4_table6_transmission_inflation.md)
  - IRF data: phase4_irf_data.csv
  - IRF figure: phase4_irf_old_vs_young.png
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# ── paths ────────────────────────────────────────────────────────────────────
# Hard-coded absolute project root; every input/output location below is
# derived from it, so this must be edited to run the script elsewhere.
PROJECT = Path("/mnt/c/demographics_capital_flows")
# Put the project's multilateral/src directory first on sys.path so that the
# project-local `model` module (which provides PanelGLS) can be imported.
sys.path.insert(0, str(PROJECT / "multilateral" / "src"))
from model import PanelGLS

# Input panel (CSV) and output directories for tables and figures.
DATA_PATH = PROJECT / "monetary" / "data" / "processed" / "monetary_panel.csv"
TABLE_DIR = PROJECT / "monetary" / "output" / "tables"
FIG_DIR = PROJECT / "monetary" / "output" / "figures"
# NOTE: the output directories are created at import time, not just when
# main() runs.
TABLE_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

# ISO3 codes (38 entries) used to select the OECD subsample from the panel.
OECD_38 = [
    "AUS", "AUT", "BEL", "CAN", "CHL", "COL", "CRI", "CZE", "DNK", "EST",
    "FIN", "FRA", "DEU", "GRC", "HUN", "ISL", "IRL", "ISR", "ITA", "JPN",
    "KOR", "LVA", "LTU", "LUX", "MEX", "NLD", "NZL", "NOR", "POL", "PRT",
    "SVK", "SVN", "ESP", "SWE", "CHE", "TUR", "GBR", "USA",
]

# Local-projection horizons; each h yields one `*_{h}` dependent variable and
# one pair of table columns (Full / OECD).
# presumably measured in years (the panel is keyed by `year`) — confirm.
HORIZONS = [1, 2, 3, 4, 5]


# ── helpers ──────────────────────────────────────────────────────────────────
def run_model(panel, dep_var, rhs_vars):
    """Fit PanelGLS of *dep_var* on *rhs_vars* and summarise the fit.

    Rows with a missing value in the dependent variable, any regressor,
    ``iso3`` or ``year`` are dropped first. If fewer than 50 complete rows
    remain, a SKIP line is printed and None is returned. Any exception raised
    while fitting is printed as an ERROR line and also yields None
    (best-effort: one failed specification does not abort the whole run).

    Returns a dict with keys ``coefs``, ``se``, ``pvals`` (each mapping
    regressor name -> value, taken from the estimator's ``beta``, ``se`` and
    ``pvalues``), plus ``r2``, ``nobs`` and ``ncountries``; or None.
    """
    needed = [dep_var] + rhs_vars + ["iso3", "year"]
    sample = panel[needed].dropna()
    n_rows = len(sample)
    if n_rows < 50:
        print(f"  SKIP {dep_var}: only {n_rows} obs after dropna")
        return None

    def by_regressor(values):
        # Pair each estimate with the regressor name it belongs to.
        return dict(zip(rhs_vars, values))

    try:
        target = sample[dep_var].values
        design = sample[rhs_vars].values
        estimator = PanelGLS()
        estimator.fit(target, design, sample["iso3"].values, sample["year"].values)
        summary = {}
        summary["coefs"] = by_regressor(estimator.beta)
        summary["se"] = by_regressor(estimator.se)
        summary["pvals"] = by_regressor(estimator.pvalues)
        summary["r2"] = estimator.r_squared
        summary["nobs"] = estimator.n_obs
        summary["ncountries"] = estimator.n_countries
        return summary
    except Exception as exc:
        print(f"  ERROR {dep_var}: {exc}")
        return None


def star(p):
    """Map a p-value to significance stars.

    ``***`` for p < 0.01, ``**`` for p < 0.05, ``*`` for p < 0.10, otherwise
    an empty string (this includes NaN, for which every comparison is false).
    """
    thresholds = ((0.01, "***"), (0.05, "**"), (0.10, "*"))
    for cutoff, marks in thresholds:
        if p < cutoff:
            return marks
    return ""


def fmt_coef(res, var):
    """Render one table cell as ``coef*** (se)``.

    *res* is a result dict as returned by run_model (or None). Returns an
    empty string when the model is missing or *var* was not a regressor;
    otherwise the coefficient and standard error are shown with three
    decimals and the stars come from the variable's p-value.
    """
    if res is None:
        return ""
    if var not in res["coefs"]:
        return ""
    estimate = res["coefs"][var]
    std_err = res["se"][var]
    p_value = res["pvals"][var]
    return f"{estimate:.3f}{star(p_value)} ({std_err:.3f})"


def build_markdown_table(title, row_vars, col_labels, results, footer_rows):
    """Assemble a pipe-format markdown table and return it as one string.

    The output is a ``# title`` heading, a header row (``Variable`` plus one
    column per entry of *col_labels*), one row per name in *row_vars* whose
    cells are ``fmt_coef(result, name)`` for each entry of *results*, a
    separator, and finally one row per ``(label, values)`` pair in
    *footer_rows*. The same ``|---|`` rule is used under the header and
    before the footer.
    """

    def as_row(first_cell, cells):
        # One markdown row: leading label cell followed by the data cells.
        return f"| {first_cell} | " + " | ".join(cells) + " |"

    rule = "|" + "---|" * (len(col_labels) + 1)

    out = [f"# {title}\n", as_row("Variable", col_labels), rule]
    out.extend(
        as_row(name, [fmt_coef(model, name) for model in results])
        for name in row_vars
    )
    out.append(rule)
    out.extend(as_row(label, values) for label, values in footer_rows)
    return "\n".join(out)


def construct_local_projection_vars(panel):
    """Return a sorted copy of *panel* with the local-projection variables added.

    The copy is ordered by (iso3, year). Columns are only created when they
    are not already present:

    * ``delta_policy_rate``  - within-country first difference of
      ``real_policy_rate``.
    * ``Z_1_x_delta_rate``   - ``Z_1 * delta_policy_rate``.
    * ``cum_growth_{h}`` / ``cum_inflation_{h}`` for h in HORIZONS - the
      within-country sum of ``rgdp_growth`` / ``inflation`` over the h rows
      starting at the current row (rows t .. t+h-1); NaN where fewer than h
      rows remain in the country.

    The windows count rows, not calendar years, so they only span h
    consecutive years if each country's series has no gaps.

    NOTE(review): with this timing ``cum_growth_1`` is identical to
    ``rgdp_growth``, which main() also uses as a control in the growth
    regressions - confirm whether the outcome was meant to start at t+1.
    """
    out = panel.sort_values(["iso3", "year"]).copy()

    def window_total(source_col, span):
        # rolling(span).sum() is backward-looking (ends at the current row);
        # shifting by 1 - span re-anchors it so the window starts there.
        def per_country(values):
            trailing = values.rolling(window=span, min_periods=span).sum()
            return trailing.shift(1 - span)

        return out.groupby("iso3")[source_col].transform(per_country)

    # Policy shock: change in the real policy rate within each country.
    if "delta_policy_rate" not in out.columns:
        out["delta_policy_rate"] = out.groupby("iso3")["real_policy_rate"].diff()
        print("  Constructed delta_policy_rate from real_policy_rate differences")

    # Interaction of Z_1 with the policy shock.
    if "Z_1_x_delta_rate" not in out.columns:
        out["Z_1_x_delta_rate"] = out["Z_1"] * out["delta_policy_rate"]
        print("  Constructed Z_1_x_delta_rate = Z_1 * delta_policy_rate")

    # Cumulative outcomes, growth first then inflation at each horizon.
    outcome_sources = (("cum_growth", "rgdp_growth"), ("cum_inflation", "inflation"))
    for h in HORIZONS:
        for prefix, source_col in outcome_sources:
            target = f"{prefix}_{h}"
            if target not in out.columns:
                out[target] = window_total(source_col, h)

    return out


# ── main ─────────────────────────────────────────────────────────────────────
def _add_cumulative_columns(df, source_col, prefix):
    """Add ``{prefix}_{h}`` columns (h in HORIZONS) to *df* in place.

    Each column is the within-country sum of *source_col* over the h rows
    starting at the current row (rows t .. t+h-1), NaN where fewer than h rows
    remain. Existing columns are left untouched. Rows are taken in the frame's
    current order, so *df* must already be sorted by year within country.
    """
    for h in HORIZONS:
        col = f"{prefix}_{h}"
        if col in df.columns:
            continue
        df[col] = df.groupby("iso3")[source_col].transform(
            # rolling sum ends at the current row; shift(-h + 1) re-anchors
            # the window so that it starts at the current row instead.
            lambda x, h=h: x.rolling(window=h, min_periods=h).sum().shift(-h + 1)
        )


def _run_transmission_table(samples, dep_prefix, rhs_vars, title, out_path):
    """Estimate one local-projection table, print it, and save it as markdown.

    For every horizon in HORIZONS and every ``(label, frame)`` pair in
    *samples*, regresses ``{dep_prefix}_{h}`` on *rhs_vars* via run_model and
    echoes the Z_1_x_delta_rate interaction to stdout.

    Returns ``(results, col_labels)``: parallel lists ordered horizon-major,
    sample-minor. Skipped or failed fits appear as None and render as blank
    cells.
    """
    results = []
    col_labels = []
    for h in HORIZONS:
        dep_var = f"{dep_prefix}_{h}"
        for sample_label, frame in samples:
            res = run_model(frame, dep_var, rhs_vars)
            results.append(res)
            col_labels.append(f"h={h} {sample_label}")
            if res:
                c_int = res["coefs"].get("Z_1_x_delta_rate", np.nan)
                p_int = res["pvals"].get("Z_1_x_delta_rate", np.nan)
                print(f"  h={h} {sample_label}: Z_1_x_delta_rate = {c_int:.3f} (p={p_int:.3f})")

    footer = [
        ("R²", [f"{r['r2']:.3f}" if r else "" for r in results]),
        ("N obs", [f"{r['nobs']:,}" if r else "" for r in results]),
        ("N countries", [f"{r['ncountries']}" if r else "" for r in results]),
    ]

    md = build_markdown_table(title, rhs_vars, col_labels, results, footer)
    print("\n" + md)
    # The table contains non-ASCII characters ("R²", "—"); pin the encoding
    # so the file does not depend on the platform's default.
    out_path.write_text(md, encoding="utf-8")
    print(f"\nSaved: {out_path}")
    return results, col_labels


def _plot_irf(irf_df, fig_path):
    """Plot the old-vs-young impulse responses with 95% bands and save a PNG.

    *irf_df* needs the columns ``horizon``, ``group`` ("old"/"young"),
    ``coef`` and ``se``; the band is coef +/- 1.96 * se.
    """
    fig, ax = plt.subplots(figsize=(8, 5))

    for group_label, color, marker in [("old", "firebrick", "s"), ("young", "steelblue", "o")]:
        gdf = irf_df[irf_df["group"] == group_label]
        h_vals = gdf["horizon"].values
        coefs = gdf["coef"].values
        ses = gdf["se"].values
        quintile = "Q5" if group_label == "old" else "Q1"

        ax.plot(h_vals, coefs, marker=marker, color=color, linewidth=2,
                label=f"{group_label.capitalize()} (Z₁ {quintile})")
        ax.fill_between(h_vals, coefs - 1.96 * ses, coefs + 1.96 * ses,
                        alpha=0.15, color=color)

    ax.axhline(0, color="black", linewidth=0.5, linestyle="--")
    ax.set_xlabel("Horizon (years)", fontsize=12)
    ax.set_ylabel("Cumulative growth response to Δrate", fontsize=12)
    ax.set_title("Impulse Response: Old vs Young OECD Economies", fontsize=13)
    ax.set_xticks(HORIZONS)
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

    fig.savefig(fig_path, dpi=150, bbox_inches="tight")
    plt.close(fig)


def main():
    """Run Phase 4: estimate the transmission tables and the old/young IRFs.

    Reads the panel from DATA_PATH, builds the local-projection variables,
    and then, for the full sample and the OECD subsample, writes Tables 5, 6,
    5b and 5c as markdown into TABLE_DIR. Also writes the IRF coefficients
    (CSV, TABLE_DIR) and the IRF figure (PNG, FIG_DIR), and prints a summary
    of the h=2 full-sample interaction. Progress is reported on stdout.
    """
    print("=" * 70)
    print("PHASE 4: MONETARY TRANSMISSION — LOCAL PROJECTIONS")
    print("=" * 70)

    panel = pd.read_csv(DATA_PATH)
    print(f"Panel: {len(panel):,} obs, {panel['iso3'].nunique()} countries")

    panel = construct_local_projection_vars(panel)

    oecd = panel[panel["iso3"].isin(OECD_38)].copy()
    print(f"OECD:  {len(oecd):,} obs, {oecd['iso3'].nunique()} countries")

    z_vars = ["Z_1", "Z_2", "Z_3"]
    controls_growth = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen",
                       "nfa_gdp_lag"]
    controls_inflation = ["rgdp_growth", "output_gap", "fiscal_bal_gdp", "kaopen",
                          "nfa_gdp_lag"]
    shock_block = ["delta_policy_rate"] + z_vars + ["Z_1_x_delta_rate"]

    # ── TABLE 5: Growth Transmission ─────────────────────────────────────
    # NOTE(review): unless the CSV already supplies it, cum_growth_1 is built
    # as a 1-row window starting at t and therefore equals rgdp_growth, which
    # is also in controls_growth - the h=1 column regresses the outcome on
    # itself. Confirm the intended timing (outcome from t+1, or lagged
    # controls) before interpreting Table 5.
    print("\n── Table 5: Growth Response to Monetary Shocks ──")
    rhs_growth = shock_block + controls_growth
    results_t5, col_labels_t5 = _run_transmission_table(
        [("Full", panel), ("OECD", oecd)],
        "cum_growth",
        rhs_growth,
        "Table 5: Growth Response to Monetary Shocks (Local Projections)",
        TABLE_DIR / "phase4_table5_transmission_growth.md",
    )

    # ── TABLE 6: Inflation Transmission ──────────────────────────────────
    print("\n── Table 6: Inflation Response to Monetary Shocks ──")
    rhs_inflation = shock_block + controls_inflation
    _run_transmission_table(
        [("Full", panel), ("OECD", oecd)],
        "cum_inflation",
        rhs_inflation,
        "Table 6: Inflation Response to Monetary Shocks (Local Projections)",
        TABLE_DIR / "phase4_table6_transmission_inflation.md",
    )

    # ── IRF: Old vs Young Quintile ───────────────────────────────────────
    print("\n── IRF: Old vs Young OECD Quintile Comparison ──")

    # Rank OECD countries by their median Z_1 and keep the bottom and top
    # quintiles (cut-offs are inclusive on both sides).
    z1_median = oecd.groupby("iso3")["Z_1"].median()
    q20 = z1_median.quantile(0.20)
    q80 = z1_median.quantile(0.80)

    young_isos = z1_median[z1_median <= q20].index.tolist()
    old_isos = z1_median[z1_median >= q80].index.tolist()

    young_panel = oecd[oecd["iso3"].isin(young_isos)].copy()
    old_panel = oecd[oecd["iso3"].isin(old_isos)].copy()

    print(f"  Young quintile: {len(young_isos)} countries ({', '.join(sorted(young_isos))})")
    print(f"  Old quintile:   {len(old_isos)} countries ({', '.join(sorted(old_isos))})")

    # Simple specification for IRF: cum_growth_h ~ delta_policy_rate + controls
    controls_irf = ["rgdp_growth", "inflation", "fiscal_bal_gdp", "kaopen", "nfa_gdp_lag"]
    rhs_irf = ["delta_policy_rate"] + controls_irf

    irf_records = []
    for h in HORIZONS:
        dep_var = f"cum_growth_{h}"
        for group_label, group_panel in [("young", young_panel), ("old", old_panel)]:
            res = run_model(group_panel, dep_var, rhs_irf)
            if res:
                coef = res["coefs"]["delta_policy_rate"]
                se = res["se"]["delta_policy_rate"]
                pval = res["pvals"]["delta_policy_rate"]
            else:
                # Keep a NaN row so every horizon/group appears in the CSV.
                coef, se, pval = np.nan, np.nan, np.nan
            irf_records.append({
                "horizon": h,
                "group": group_label,
                "coef": coef,
                "se": se,
                "pval": pval,
            })
            if not np.isnan(coef):
                print(f"  h={h} {group_label}: delta_rate coef = {coef:.3f} "
                      f"(se={se:.3f}, p={pval:.3f})")

    irf_df = pd.DataFrame(irf_records)
    irf_csv_path = TABLE_DIR / "phase4_irf_data.csv"
    irf_df.to_csv(irf_csv_path, index=False)
    print(f"\nSaved IRF data: {irf_csv_path}")

    # ── Plot IRF ─────────────────────────────────────────────────────────
    fig_path = FIG_DIR / "phase4_irf_old_vs_young.png"
    _plot_irf(irf_df, fig_path)
    print(f"Saved IRF figure: {fig_path}")

    # ── TABLE 5b: Investment Channel Local Projections ──────────────────
    print("\n── Table 5b: Investment Channel Local Projections ──")

    # cum_inv_growth_h = h-row forward sum of the YoY change in I/GDP.
    panel["d_inv_gdp"] = panel.groupby("iso3")["gross_investment_gdp"].diff()
    _add_cumulative_columns(panel, "d_inv_gdp", "cum_inv_growth")

    # Recompute OECD subset so it carries the new columns.
    oecd = panel[panel["iso3"].isin(OECD_38)].copy()

    _run_transmission_table(
        [("Full", panel), ("OECD", oecd)],
        "cum_inv_growth",
        shock_block + controls_growth,
        "Table 5b: Investment Channel — Response to Monetary Shocks (Local Projections)",
        TABLE_DIR / "phase4_table5b_transmission_investment.md",
    )

    # ── TABLE 5c: Consumption Channel Local Projections ───────────────
    print("\n── Table 5c: Consumption Channel Local Projections ──")

    # Consumption proxy: C/GDP ≈ 100 - I/GDP - CA/GDP
    panel["consumption_proxy_gdp"] = 100.0 - panel["gross_investment_gdp"] - panel["ca_gdp"]
    panel["d_cons_gdp"] = panel.groupby("iso3")["consumption_proxy_gdp"].diff()
    _add_cumulative_columns(panel, "d_cons_gdp", "cum_cons_change")

    # Recompute OECD subset so it carries the new columns.
    oecd = panel[panel["iso3"].isin(OECD_38)].copy()

    _run_transmission_table(
        [("Full", panel), ("OECD", oecd)],
        "cum_cons_change",
        shock_block + controls_growth,
        "Table 5c: Consumption Channel — Response to Monetary Shocks (Local Projections)",
        TABLE_DIR / "phase4_table5c_transmission_consumption.md",
    )

    # ── Summary ──────────────────────────────────────────────────────────
    print("\n── Key Test: Does aging weaken monetary transmission? ──")
    # Look the h=2 full-sample model up by its column label rather than by a
    # hard-coded position in the results list.
    h2_full = results_t5[col_labels_t5.index("h=2 Full")]
    if h2_full and "Z_1_x_delta_rate" in h2_full["coefs"]:
        c = h2_full["coefs"]["Z_1_x_delta_rate"]
        p = h2_full["pvals"]["Z_1_x_delta_rate"]
        direction = "WEAKENS" if c > 0 else "STRENGTHENS"
        sig = "significant" if p < 0.10 else "NOT significant"
        print(f"  h=2 Full: Z_1_x_delta_rate = {c:.3f} (p={p:.3f})")
        print(f"  -> Aging {direction} transmission ({sig})")
    else:
        print("  Could not estimate key interaction at h=2.")

    print("\nPhase 4 complete.")


# Run the full Phase 4 pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
